#!/bin/bash
# Launch script for distributed (multi-node) Aquila/Megatron-style GPT
# pretraining via torchrun.
#
# Usage: launch.sh PROJ_HOME EXPNAME LOAD_EXPNAME
#   PROJ_HOME     - project root (contains config/hostfile; receives
#                   checkpoints/, logs/, tboard/, wandb/ output dirs)
#   EXPNAME       - experiment name used for all output directories
#   LOAD_EXPNAME  - experiment name whose checkpoint directory is loaded

# --- NCCL / runtime environment ---
# NOTE(review): commonly required for Megatron tensor/sequence parallelism;
# confirm against the framework version in use.
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_SOCKET_IFNAME=bond0
export NCCL_IB_DISABLE=0
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_IB_GID_INDEX=0
export NCCL_IB_HCA=mlx5_0,mlx5_1
# NOTE(review): "DEBUG" is not a documented NCCL_DEBUG level (valid values
# are VERSION, WARN, INFO, TRACE), so NCCL logging is effectively disabled
# here. Change to INFO or WARN if log output is actually wanted.
export NCCL_DEBUG=DEBUG
export OMP_NUM_THREADS=4
export GLOO_SOCKET_IFNAME=bond0

# Fail fast with a clear usage message instead of relying on `set -u` to
# abort with an "unbound variable" error.
if [ "$#" -lt 3 ]; then
    echo "Usage: $0 PROJ_HOME EXPNAME LOAD_EXPNAME" >&2
    exit 1
fi
PROJ_HOME=$1
EXPNAME=$2
LOAD_EXPNAME=$3

# Datasets: datasets.sh (expected next to this script) defines the
# DATASET_* path variables; pick the mixture used for this run.
# Quoted expansions so paths containing spaces do not word-split (SC2046/SC2086).
source "$(dirname "$0")/datasets.sh"
DATA_PATH=$DATASET_DX_KNOWLEDGE6

# Hostfile: one line per node; parsed below to derive the torchrun topology.
HOSTFILE="$PROJ_HOME/config/hostfile"
echo "$HOSTFILE"

# Per-experiment output locations under $PROJ_HOME.
# Quoted expansions so paths containing spaces do not word-split (SC2086).
CHECKPOINT_PATH="$PROJ_HOME/checkpoints/$EXPNAME"            # new checkpoints saved here
LOAD_CHECKPOINT_PATH="$PROJ_HOME/checkpoints/$LOAD_EXPNAME"  # checkpoint to resume/init from
mkdir -p "$CHECKPOINT_PATH"
VOCAB_FILE=examples/aquila/tokenizer/vocab.json
MERGE_FILE=examples/aquila/tokenizer/merges.txt
LOG_PATH="$PROJ_HOME/logs/$EXPNAME"
mkdir -p "$LOG_PATH"
# Archive this launch script alongside the logs for reproducibility.
cp "$0" "$LOG_PATH"/
TB_PATH="$PROJ_HOME/tboard/$EXPNAME"
mkdir -p "$TB_PATH"
WB_PATH="$PROJ_HOME/wandb/$EXPNAME"
mkdir -p "$WB_PATH"

# Change for multinode config
# Derive this node's identity and the cluster topology from the hostfile.
# NODE_ADDR: first non-loopback IPv4 address on bond0. `tr -d "addr:"`
# deletes the characters a/d/r/: (strips the legacy "addr:" prefix some
# ifconfig versions print; harmless for a dotted-quad IP, which contains
# none of those characters).
export NODE_ADDR=$(ifconfig bond0|grep inet|grep -v 127.0.0.1|grep -v inet6|awk '{print $2;}'|tr -d "addr:"|head -n 1)
# GPUS_PER_NODE: find this node's line in the hostfile and take its last
# "="/space-separated field (assumes lines like "<ip> slots=<n>" -- TODO
# confirm the exact hostfile format).
export GPUS_PER_NODE=$(awk '{$1=$1;print}' $HOSTFILE|awk -F" |=" '{ranks[$1]=$NF;}END{print ranks["'$NODE_ADDR'"];}')
# NNODES: one hostfile line per node. NOTE(review): `wc -l` undercounts if
# the file lacks a trailing newline -- verify the hostfile ends with one.
export NNODES=$(cat $HOSTFILE | wc -l)
# MASTER_ADDR: the first host listed acts as the rendezvous master.
export MASTER_ADDR=$(head -n1 $HOSTFILE | awk '{print $1;}')
# NODE_RANK: zero-based position of this node's IP within the hostfile.
export NODE_RANK=$(awk '{ranks[$1]=(FNR-1);}END{print ranks["'$NODE_ADDR'"];}' $HOSTFILE)
export MASTER_PORT=12345
# Total rank count across the job. Computed here but not referenced later
# in this script (and not exported).
WORLD_SIZE=$(($GPUS_PER_NODE * $NNODES))

# torchrun rendezvous/topology flags (kept on one line; the variable is
# expanded unquoted later, so only word boundaries matter).
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

# Training schedule, batch sizing, parallelism layout and activation
# recomputation flags, assembled incrementally.
TRAINING_ARGS="--train-samples 488281250"
TRAINING_ARGS+=" --rampup-batch-size 16 16 3906250"
TRAINING_ARGS+=" --eval-iters 0"
TRAINING_ARGS+=" --eval-interval 2000"
TRAINING_ARGS+=" --tensor-model-parallel-size 4"
TRAINING_ARGS+=" --pipeline-model-parallel-size 8"
TRAINING_ARGS+=" --micro-batch-size 1"
TRAINING_ARGS+=" --global-batch-size 1024"
TRAINING_ARGS+=" --disable-bias-linear"
TRAINING_ARGS+=" --use-flash-attn"
TRAINING_ARGS+=" --sequence-parallel"
# The single quotes below are literal characters in this string; the final
# `eval` of the assembled command re-parses and strips them.
TRAINING_ARGS+=" --recompute-granularity 'full'"
TRAINING_ARGS+=" --recompute-method 'uniform'"
TRAINING_ARGS+=" --recompute-num-layers 1"
TRAINING_ARGS+=" --use-distributed-optimizer"

MIXED_PRECISION_ARGS="
    --bf16 \
    --initial-loss-scale 65536 \
    --min-loss-scale 1.0 \
    --loss-scale-window 1024 \
    --attention-softmax-in-fp32 \
    --embedding-weights-in-fp32 \
    --accumulate-allreduce-grads-in-fp32
"

# Dataset, tokenizer and train/valid/test split configuration.
DATA_ARGS="--data-path $DATA_PATH"
DATA_ARGS+=" --tokenizer-type AquilaTokenizer"
DATA_ARGS+=" --vocab-file $VOCAB_FILE"
DATA_ARGS+=" --vocab-size 100008"
DATA_ARGS+=" --merge-file $MERGE_FILE"
DATA_ARGS+=" --data-impl mmap"
DATA_ARGS+=" --split 1"

NETWORK_ARGS="
    --num-layers 80 \
    --hidden-size 8192 \
    --num-attention-heads 64 \
    --num-kv-attention-heads 8 \
    --hidden-dim-multiplier 1.3 \
    --seq-length 4096 \
    --max-position-embeddings 4096 \
    --layernorm-epsilon 1e-5 \
    --use-rotary-position-embeddings \
    --rotary-position-embeddings-in-fp32 \
    --no-position-embedding \
    --swiglu \
    --multiple-of 4096 \
    --apply-layernorm-rms \
    --make-vocab-size-divisible-by 64 \
    --untie-embeddings-and-output-weights
"

# Customized per-layer weight initialization. Each --init-method-std-scaled-*
# and --init-weight-*-norm flag carries one value per transformer layer
# (cf. --num-layers 80 in NETWORK_ARGS); these tables were presumably
# generated offline -- do not edit by hand. No comments can be placed inside
# the quoted block below: everything between the quotes is argument text.
INITIALIZATION_ARGS="
    --apply-init-customized \
    --init-method-std-scaled-embed 0.0123291015625 \
    --init-method-std-scaled-attn-q 0.01025390625 0.01373291015625 0.0174560546875 0.0162353515625 0.01531982421875 0.01611328125 0.0166015625 0.016357421875 0.0164794921875 0.0150146484375 0.01409912109375 0.01470947265625 0.0155029296875 0.01519775390625 0.0155029296875 0.01422119140625 0.01434326171875 0.015625 0.01446533203125 0.0142822265625 0.01409912109375 0.01336669921875 0.0130615234375 0.013916015625 0.0140380859375 0.01483154296875 0.01446533203125 0.012939453125 0.013427734375 0.01312255859375 0.01123046875 0.0130615234375 0.01348876953125 0.0145263671875 0.01336669921875 0.01324462890625 0.01361083984375 0.01409912109375 0.01446533203125 0.0152587890625 0.0150146484375 0.014404296875 0.01409912109375 0.013671875 0.01470947265625 0.0128173828125 0.01251220703125 0.013916015625 0.0118408203125 0.0107421875 0.01116943359375 0.0118408203125 0.01300048828125 0.01055908203125 0.01080322265625 0.0120849609375 0.01141357421875 0.010986328125 0.010498046875 0.01043701171875 0.0084228515625 0.0098876953125 0.01007080078125 0.00982666015625 0.0108642578125 0.00885009765625 0.010498046875 0.00927734375 0.01116943359375 0.01226806640625 0.011962890625 0.0125732421875 0.01312255859375 0.01251220703125 0.01348876953125 0.0128173828125 0.0140380859375 0.01324462890625 0.01287841796875 0.012451171875 \
    --init-method-std-scaled-attn-k 0.0230712890625 0.0257568359375 0.030517578125 0.0267333984375 0.023193359375 0.024169921875 0.024169921875 0.023681640625 0.025634765625 0.0233154296875 0.021484375 0.0228271484375 0.0242919921875 0.0234375 0.022705078125 0.021240234375 0.0223388671875 0.0242919921875 0.02197265625 0.022216796875 0.0224609375 0.021728515625 0.0220947265625 0.0225830078125 0.0224609375 0.024658203125 0.0228271484375 0.021728515625 0.021240234375 0.0206298828125 0.0185546875 0.0206298828125 0.0213623046875 0.021728515625 0.022216796875 0.0205078125 0.0211181640625 0.021484375 0.021484375 0.02197265625 0.021484375 0.021728515625 0.020751953125 0.0208740234375 0.0206298828125 0.019287109375 0.0189208984375 0.0216064453125 0.01806640625 0.017578125 0.016845703125 0.0179443359375 0.0184326171875 0.0162353515625 0.0159912109375 0.018310546875 0.0181884765625 0.0166015625 0.0166015625 0.016357421875 0.0142822265625 0.0159912109375 0.0162353515625 0.015869140625 0.017822265625 0.01519775390625 0.0166015625 0.016845703125 0.017578125 0.0191650390625 0.0184326171875 0.019287109375 0.020263671875 0.0191650390625 0.019775390625 0.0186767578125 0.01953125 0.01904296875 0.01904296875 0.01953125 \
    --init-method-std-scaled-attn-v 0.00860595703125 0.0084228515625 0.0091552734375 0.01068115234375 0.01031494140625 0.01055908203125 0.01068115234375 0.0103759765625 0.01055908203125 0.01171875 0.0125732421875 0.0118408203125 0.0123291015625 0.01171875 0.01171875 0.01263427734375 0.0125732421875 0.01220703125 0.0133056640625 0.013671875 0.0128173828125 0.0126953125 0.0140380859375 0.013671875 0.01263427734375 0.01220703125 0.01312255859375 0.0140380859375 0.01446533203125 0.01416015625 0.0145263671875 0.01397705078125 0.01397705078125 0.01361083984375 0.013916015625 0.0137939453125 0.01397705078125 0.01416015625 0.01361083984375 0.013916015625 0.0147705078125 0.01434326171875 0.01483154296875 0.01470947265625 0.01531982421875 0.015869140625 0.015869140625 0.0152587890625 0.01611328125 0.015869140625 0.01611328125 0.0164794921875 0.0162353515625 0.0167236328125 0.0166015625 0.0172119140625 0.0166015625 0.0172119140625 0.0164794921875 0.0167236328125 0.0167236328125 0.01708984375 0.016845703125 0.0164794921875 0.017333984375 0.0162353515625 0.0166015625 0.0157470703125 0.0181884765625 0.0191650390625 0.017822265625 0.0177001953125 0.017578125 0.0186767578125 0.017578125 0.0191650390625 0.020263671875 0.0194091796875 0.017822265625 0.0146484375 \
    --init-method-std-scaled-ffn-w1 0.00958251953125 0.01287841796875 0.01409912109375 0.014404296875 0.0145263671875 0.01458740234375 0.01470947265625 0.0147705078125 0.01507568359375 0.01513671875 0.0152587890625 0.01519775390625 0.01507568359375 0.01495361328125 0.0150146484375 0.0150146484375 0.01495361328125 0.01495361328125 0.0150146484375 0.0150146484375 0.0150146484375 0.0150146484375 0.01495361328125 0.01495361328125 0.014892578125 0.0147705078125 0.0147705078125 0.0147705078125 0.01483154296875 0.01483154296875 0.01483154296875 0.014892578125 0.01483154296875 0.01483154296875 0.0147705078125 0.01483154296875 0.01483154296875 0.01483154296875 0.01483154296875 0.01483154296875 0.01483154296875 0.0150146484375 0.01519775390625 0.0152587890625 0.0155029296875 0.015625 0.0157470703125 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.015869140625 0.0157470703125 0.0157470703125 0.0157470703125 0.0157470703125 0.0157470703125 0.0157470703125 0.015625 0.015625 0.015625 0.01556396484375 0.01556396484375 0.0155029296875 0.0155029296875 0.0155029296875 0.01544189453125 0.01544189453125 0.01556396484375 0.015625 0.015869140625 0.01611328125 \
    --init-method-std-scaled-ffn-w2 0.011474609375 0.0135498046875 0.0140380859375 0.01416015625 0.01416015625 0.01416015625 0.01422119140625 0.0142822265625 0.01434326171875 0.01434326171875 0.01434326171875 0.01434326171875 0.01434326171875 0.014404296875 0.01434326171875 0.014404296875 0.01434326171875 0.014404296875 0.014404296875 0.01446533203125 0.01446533203125 0.0145263671875 0.01446533203125 0.01446533203125 0.01446533203125 0.0145263671875 0.01446533203125 0.01446533203125 0.01446533203125 0.01446533203125 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.01446533203125 0.0145263671875 0.01446533203125 0.01446533203125 0.01446533203125 0.014404296875 0.01434326171875 0.01434326171875 0.0142822265625 0.0142822265625 0.0142822265625 0.0142822265625 0.01434326171875 0.01434326171875 0.01434326171875 0.01434326171875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.01446533203125 0.01446533203125 0.01446533203125 0.01446533203125 0.01446533203125 0.0145263671875 0.0145263671875 0.01458740234375 0.0146484375 0.0146484375 0.0146484375 0.01470947265625 0.01470947265625 0.01470947265625 0.01470947265625 0.01470947265625 0.0147705078125 0.0147705078125 0.01483154296875 \
    --init-method-std-scaled-ffn-w3 0.00933837890625 0.0125732421875 0.01348876953125 0.0137939453125 0.013916015625 0.01397705078125 0.0140380859375 0.01409912109375 0.01416015625 0.01422119140625 0.0142822265625 0.01434326171875 0.01434326171875 0.0142822265625 0.01434326171875 0.01434326171875 0.01434326171875 0.0142822265625 0.0142822265625 0.01434326171875 0.014404296875 0.01446533203125 0.01446533203125 0.01446533203125 0.01446533203125 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.014404296875 0.01434326171875 0.01434326171875 0.01434326171875 0.01434326171875 0.01434326171875 0.014404296875 0.014404296875 0.0142822265625 0.01422119140625 0.01422119140625 0.01409912109375 0.01409912109375 0.01409912109375 0.0140380859375 0.0140380859375 0.01409912109375 0.01409912109375 0.01416015625 0.01416015625 0.01416015625 0.01422119140625 0.0142822265625 0.0142822265625 0.01434326171875 0.01434326171875 0.01434326171875 0.014404296875 0.014404296875 0.01446533203125 0.01446533203125 0.0145263671875 0.0145263671875 0.01458740234375 0.0146484375 0.0146484375 0.01470947265625 0.0147705078125 0.0147705078125 0.0147705078125 0.01483154296875 0.01483154296875 0.01483154296875 0.014892578125 0.01495361328125 0.01507568359375 0.01544189453125 \
    --init-method-std-scaled-output 0.01513671875 \
    --init-weight-attn-norm 0.00823974609375 0.0277099609375 0.0361328125 0.0927734375 0.1279296875 0.1396484375 0.140625 0.171875 0.158203125 0.1787109375 0.1572265625 0.18359375 0.154296875 0.1875 0.2177734375 0.216796875 0.2080078125 0.20703125 0.2275390625 0.2392578125 0.2353515625 0.21875 0.2099609375 0.2216796875 0.20703125 0.2158203125 0.248046875 0.212890625 0.251953125 0.265625 0.24609375 0.267578125 0.255859375 0.27734375 0.208984375 0.25390625 0.24609375 0.27734375 0.2890625 0.29296875 0.291015625 0.28125 0.302734375 0.26953125 0.326171875 0.296875 0.28515625 0.2890625 0.275390625 0.248046875 0.279296875 0.28515625 0.322265625 0.265625 0.28125 0.294921875 0.271484375 0.287109375 0.25390625 0.26953125 0.1982421875 0.248046875 0.265625 0.25 0.271484375 0.224609375 0.2421875 0.189453125 0.3203125 0.349609375 0.29296875 0.302734375 0.322265625 0.328125 0.30859375 0.314453125 0.34765625 0.3125 0.322265625 0.283203125 \
    --init-weight-ffn-norm 0.02490234375 0.049560546875 0.0693359375 0.0830078125 0.09326171875 0.09716796875 0.10498046875 0.11181640625 0.1259765625 0.13671875 0.1474609375 0.1513671875 0.1513671875 0.154296875 0.162109375 0.1669921875 0.1708984375 0.171875 0.1796875 0.1865234375 0.193359375 0.197265625 0.2021484375 0.203125 0.2001953125 0.1962890625 0.2060546875 0.2109375 0.21484375 0.2197265625 0.224609375 0.228515625 0.2314453125 0.2333984375 0.23828125 0.244140625 0.248046875 0.251953125 0.255859375 0.2578125 0.263671875 0.26953125 0.27734375 0.28125 0.291015625 0.30078125 0.30859375 0.3125 0.318359375 0.32421875 0.330078125 0.33203125 0.337890625 0.341796875 0.345703125 0.34765625 0.353515625 0.35546875 0.361328125 0.365234375 0.369140625 0.37109375 0.375 0.380859375 0.384765625 0.388671875 0.392578125 0.396484375 0.40234375 0.404296875 0.41015625 0.41796875 0.423828125 0.4296875 0.435546875 0.439453125 0.439453125 0.431640625 0.400390625 0.32421875 \
    --init-weight-output-norm 1.1171875 \
    --init-method-std 0.0064 \
    --seed 42
"

# Dropout, weight decay, Adam hyperparameters and gradient clipping.
REGULARIZATION_ARGS="--attention-dropout 0.0 --hidden-dropout 0.0"
REGULARIZATION_ARGS+=" --weight-decay 0.1"
REGULARIZATION_ARGS+=" --adam-beta1 0.9 --adam-beta2 0.95 --adam-eps 1.0e-5"
REGULARIZATION_ARGS+=" --clip-grad 1.0"

# Learning-rate schedule: cosine decay from 1.5e-4 down to 1.5e-5 after a
# sample-based warmup.
LEARNING_RATE_ARGS="--lr 1.5e-4 --lr-decay-style cosine --lr-warmup-samples 2048000 --min-lr 1.5e-5"

# Checkpoint cadence plus save/load locations (save and load may differ so a
# run can be initialized from another experiment's checkpoint).
CHECKPOINTING_ARGS="--save-interval 2000 --save $CHECKPOINT_PATH --load $LOAD_CHECKPOINT_PATH"

# Console/TensorBoard/W&B logging configuration.
LOGGING_ARGS="--log-interval 1"
LOGGING_ARGS+=" --log-memory-to-tensorboard"
LOGGING_ARGS+=" --log-world-size-to-tensorboard"
LOGGING_ARGS+=" --tensorboard-dir $TB_PATH"
LOGGING_ARGS+=" --tensorboard-log-interval 1"
LOGGING_ARGS+=" --wandb-dir $WB_PATH"

# Assemble the final torchrun command from the flag groups above. The *_ARGS
# values contain embedded newlines (from their multi-line definitions);
# that is handled below.
cmd="torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
              $TRAINING_ARGS \
              $MIXED_PRECISION_ARGS \
              $DATA_ARGS \
              $NETWORK_ARGS \
              $INITIALIZATION_ARGS \
              $REGULARIZATION_ARGS \
              $LEARNING_RATE_ARGS \
              $CHECKPOINTING_ARGS \
              $LOGGING_ARGS
    "
# Unquoted on purpose: word-splitting collapses the newlines so the command
# logs as a single line.
echo $cmd
# `eval $cmd` (unquoted) is load-bearing twice over: word-splitting turns the
# embedded newlines into plain argument separators (a quoted "$cmd" would make
# each newline a command terminator), and eval's re-parse strips the literal
# single quotes around values like 'full' in TRAINING_ARGS. Do not "fix" this
# into `eval "$cmd"` or a direct $cmd invocation.
eval $cmd
